# Packages 
from astropy.io import fits
from astropy import wcs
from astropy.coordinates import SkyCoord
from aplpy import FITSFigure
import math
import numpy as np
from scipy import stats, optimize

# ===
# Functions to calculate the orientation of the filaments
# ===

# Main function: skeleton positions -> straight-line fit -> orientation angle
def FitOrientation(FitsFile):
    """Measure the orientation of a filament skeleton stored in a FITS file.

    The skeleton pixels are collected with PositionList, a straight line is
    fitted to them with LinearFit, and the fit is turned into an on-sky angle
    with CalculateAngle.

    Parameters
    ----------
    FitsFile : HDUList-like
        Container whose first HDU holds the skeleton image and its WCS header.

    Returns
    -------
    tuple
        ``(SkeletonAngle, LineParameters, PosArray)``, i.e. the outputs of
        CalculateAngle, LinearFit and PositionList respectively.
    """
    positions = PositionList(FitsFile)
    fit = LinearFit(positions)
    return CalculateAngle(positions, fit), fit, positions

# Function to list all the positions in the skeleton
def PositionList(FitsFile):
    """Return the galactic coordinates of every skeleton pixel in a FITS image.

    A pixel belongs to the skeleton when its value is strictly greater than
    0.0 (so zero, negative and NaN pixels are excluded).

    Parameters
    ----------
    FitsFile : HDUList-like
        Container whose first HDU (``FitsFile[0]``) holds the 2-D skeleton
        image in ``.data`` and its celestial WCS in ``.header``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2, N)``: row 0 is galactic longitude ``l`` and row 1
        galactic latitude ``b``, both in degrees. Points are ordered by pixel
        column (x) first and pixel row (y) second. ``N`` may be 0.

    Raises
    ------
    ValueError
        If the image data is not two-dimensional.
    """
    hdu = FitsFile[0]
    WCSinfo = wcs.WCS(hdu.header)
    data = np.asarray(hdu.data)
    if data.ndim != 2:
        raise ValueError("PositionList expects a 2-D skeleton image")

    # FITS images are indexed data[y, x]. Transposing makes np.nonzero return
    # (x, y) index pairs sorted by x and then y, i.e. the same order as a
    # column-outer / row-inner scan of the image.
    PixX, PixY = np.nonzero(data.T > 0.0)
    if PixX.size == 0:
        return np.empty((2, 0))

    # One vectorised pixel -> sky conversion instead of one SkyCoord per pixel.
    # Transforming to galactic makes this work for any celestial WCS; it is an
    # identity for maps that are already in galactic coordinates.
    PixCoord = SkyCoord.from_pixel(PixX, PixY, WCSinfo).galactic
    return np.array([PixCoord.l.deg, PixCoord.b.deg])

# Function to fit a straight line to the skeleton positions
def LinearFit(PosArray):
    """Least-squares fit of latitude as a straight line in longitude.

    Parameters
    ----------
    PosArray : numpy.ndarray
        ``(2, N)`` array of positions; row 0 is used as the independent
        variable (longitude) and row 1 as the dependent one (latitude).

    Returns
    -------
    numpy.ndarray
        ``[slope, intercept]`` such that ``b = slope * l + intercept``, as
        found by ``scipy.optimize.curve_fit``.
    """
    def _straight_line(x, slope, intercept):
        return slope * x + intercept

    longitudes = PosArray[0, :]
    latitudes = PosArray[1, :]
    best_fit, _covariance = optimize.curve_fit(_straight_line, longitudes, latitudes)
    return best_fit

# Calculating the angle relative to galactic North
def CalculateAngle(PosArray,LineParameters):
    """Return the position angle of the fitted skeleton line, in degrees.

    The fitted line ``b = LineParameters[0] * l + LineParameters[1]`` is
    evaluated at the smallest and largest skeleton longitude. The angle is
    astropy's on-sky position angle (East of North, here in the galactic
    frame) of the segment going from the minimum-longitude end point to the
    maximum-longitude end point.

    Parameters
    ----------
    PosArray : numpy.ndarray
        ``(2, N)`` skeleton positions in degrees (row 0 = l, row 1 = b).
    LineParameters : sequence
        ``(slope, intercept)`` of the fitted line, e.g. from LinearFit.

    Returns
    -------
    float
        Position angle in degrees.

    Raises
    ------
    ValueError
        From ``argmax``/``argmin`` if ``PosArray`` contains no points.

    Notes
    -----
    If every skeleton point has the same longitude, both end points coincide
    and the returned angle is not meaningful.
    """
    Slope, Intercept = LineParameters[0], LineParameters[1]
    # Work in float so the fitted latitudes are never truncated, even if the
    # caller passes an integer-typed position array.
    Longitudes = np.asarray(PosArray[0, :], dtype=float)

    # Extremities of the fitted line: only these two points are needed, so
    # evaluate the fit there instead of at every skeleton position.
    LMin = Longitudes[Longitudes.argmin()]
    LMax = Longitudes[Longitudes.argmax()]
    SkyCoordMin = SkyCoord(LMin, Slope*LMin + Intercept, frame='galactic', unit='deg')
    SkyCoordMax = SkyCoord(LMax, Slope*LMax + Intercept, frame='galactic', unit='deg')

    # Angle measured from the minimum-longitude end to the maximum-longitude end
    return SkyCoordMin.position_angle(SkyCoordMax).deg

# =============================================================================
# Skeleton on Column Density Map
# =============================================================================

def _FilamentLineSegment(PosArray, LineParameters):
    """Return the fitted filament line as a (2, 2) array for ``show_lines``.

    Row 0 holds the smallest and largest skeleton longitude. Row 1 holds the
    fitted latitude ``LineParameters[0] * l + LineParameters[1]`` at those two
    longitudes. This is the ``xw, yw = line[0, :], line[1, :]`` layout that
    aplpy's ``show_lines`` reads.
    """
    Longitudes = np.asarray(PosArray[0, :], dtype=float)
    LEnds = np.array([Longitudes.min(), Longitudes.max()])
    return np.vstack([LEnds, LineParameters[0]*LEnds + LineParameters[1]])


def SkeletonMap(Nref, Iref, ScaleLength, PosArray, LineParameters, Region,
                FigSize, Ilevels, Nscale, Colorbar='top', BeamPos='bottom left',
                Scalebar='top right', Linewidth=2.0, HAWCBeam=0.00506,
                FilLabel=None, FilLabelX=0.5, FilLabelY=0.95):
    """Plot the filament skeleton and its linear fit on a column-density map.

    Parameters
    ----------
    Nref : HDU
        Column-density image (``.data`` and ``.header``). The data are
        multiplied by 1e-22 for display; the colour bar is labelled in units
        of 10^22 cm^-2 (presumably Nref is in cm^-2; confirm with the caller).
    Iref :
        Total-intensity map passed to ``show_contour`` (drawn in white).
    ScaleLength : float
        Length of the scale bar labelled "1 pc", passed to ``add_scalebar``.
    PosArray : numpy.ndarray
        ``(2, N)`` skeleton positions (row 0 = l, row 1 = b), e.g. from
        PositionList.
    LineParameters : sequence
        ``(slope, intercept)`` of ``b = slope * l + intercept``, e.g. from
        LinearFit.
    Region : sequence
        ``(x, y, width, height)`` passed to ``recenter``.
    FigSize : sequence
        ``(width, height)`` of the figure.
    Ilevels : sequence
        Contour levels for ``Iref``.
    Nscale : sequence
        ``(vmin, vmax)`` of the colour scale in the units of ``Nref.data``;
        these are multiplied by 1e-22 like the data.
    Colorbar : str
        Colour-bar location passed to ``colorbar.set_location``.
    BeamPos, Scalebar : str
        Corners for the beams and the scale bar.
    Linewidth : float
        NOTE(review): accepted but currently unused. The fitted line is drawn
        with linewidth 2.5 and the beams with linewidth 2.
    HAWCBeam : float
        HAWC+ beam size given to ``set_major``/``set_minor``. It uses the same
        units as the Herschel value 0.01011 (= 36.4'', i.e. degrees).
    FilLabel : str or None
        Optional label, placed at relative coordinates (FilLabelX, FilLabelY).

    Returns
    -------
    aplpy.FITSFigure
        The figure, so the caller can further customise or save it.
    """
    # Rescale so the colour bar reads in units of 10^22; building a new HDU
    # leaves Nref itself untouched.
    Nfits = fits.PrimaryHDU(data=Nref.data*10**-22, header=Nref.header)

    # Column-density background, colour bar and scale bar
    N_plot = FITSFigure(Nfits, figsize=(FigSize[0], FigSize[1]))
    N_plot.tick_labels.set_xformat('dd.d')
    N_plot.tick_labels.set_yformat('d.dd')
    N_plot.show_colorscale(cmap='Greys', vmin=Nscale[0]*10**-22,
                           vmax=Nscale[1]*10**-22)
    N_plot.add_colorbar(pad=0.25)
    N_plot.colorbar.set_location(Colorbar)
    N_plot.colorbar.set_axis_label_text(r'$\it{Herschel}$ Column Density $N_{H_2}$ ($10^{22}$ cm$^{-2}$)')
    N_plot.add_scalebar(ScaleLength, '1 pc', corner=Scalebar, frame=True)

    # Herschel beam: 0.01011 deg = 36.4'' (see Aniano et al. 2011)
    N_plot.add_beam(facecolor='green', edgecolor='green', zorder=1,
                    linewidth=2, pad=1, corner=BeamPos)
    N_plot.beam.set_major(0.01011)
    N_plot.beam.set_minor(0.01011)
    # HAWC+ beam in the same corner; once two beams exist the second one is
    # addressed as N_plot.beam[1].
    N_plot.add_beam(facecolor='red', edgecolor='red',
                    linewidth=2, pad=1, corner=BeamPos)
    N_plot.beam[1].set_major(HAWCBeam)
    N_plot.beam[1].set_minor(HAWCBeam)

    # Total-intensity contours, then zoom to the region of interest
    N_plot.show_contour(data=Iref, colors='white', levels=Ilevels)
    N_plot.recenter(x=Region[0], y=Region[1], width=Region[2], height=Region[3])

    # Spine of the filament
    N_plot.show_markers(PosArray[0, :], PosArray[1, :], edgecolor='black',
                        facecolor='dodgerblue', marker='o', s=50, zorder=3,
                        linewidth=0.4)

    # Fitted line between the skeleton's longitude extremes: the same end
    # points used to measure the orientation angle, so the drawn segment
    # matches the measured angle regardless of pixel-scan order.
    N_plot.show_lines([_FilamentLineSegment(PosArray, LineParameters)],
                      color='red', linewidth=2.5)

    # Optional label for the filament
    if FilLabel is not None:
        N_plot.add_label(FilLabelX, FilLabelY, FilLabel, relative=True, size=13)

    return N_plot